{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Auto-Sklearn cannot be imported.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x10be48570>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "import numpy as np\n",
    "import pyepo\n",
    "from pyepo.model.grb import optGrbModel\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "import time as time\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR\n",
    "from ModelsKnapSack import KnapSackNet, ValPredictNet \n",
    "\n",
    "# Set the random seed\n",
    "seed = 42\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate data\n",
    "num_data = 10000 # number of data\n",
    "num_feat = 5 # size of feature\n",
    "num_item = 100 # number of items\n",
    "num_constraints = 2 # number of knapsacks\n",
    "weights_numpy, contexts_numpy, costs_numpy = pyepo.data.knapsack.genData(num_data, num_feat, num_item,\n",
    "                                                dim=num_constraints, deg=1, noise_width=0.5, seed=135)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'mps'\n",
    "weights = torch.Tensor(weights_numpy).to(device)\n",
    "# print(weights)\n",
    "## NB: Weights here refer to the weights of the items in the knapsack problem.\n",
    "# They become part of the constraints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/danielmckenzie/My-Drive/Research/Fixed_Point_Networks/Diff-Opt-Over-Polytopes-Project/SPO_with_DYS/Knapsack_Testing/dYS_opt_net.py:29: UserWarning: The operator 'aten::linalg_svd' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)\n",
      "  U, s, VT = torch.linalg.svd(self.A, full_matrices=False)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "KnapSackNet(\n",
       "  (fc_1): Linear(in_features=5, out_features=50, bias=True)\n",
       "  (fc_2): Linear(in_features=50, out_features=50, bias=True)\n",
       "  (fc_3): Linear(in_features=50, out_features=100, bias=True)\n",
       "  (leaky_relu): LeakyReLU(negative_slope=0.1)\n",
       "  (relu): ReLU()\n",
       "  (dropout): Dropout(p=0.3, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "capacities_numpy = 20*torch.ones(2)\n",
    "capacities = capacities_numpy.to(device)\n",
    "capacities_DYS = 18*torch.ones(2, device=device)\n",
    "DYS_net = KnapSackNet(weights, capacities_DYS, num_constraints, num_item, num_feat, device=device)\n",
    "DYS_net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ans = DYS_net(contexts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(ans)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Setting up the Training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Package(s) not found: sklearn\u001b[0m\u001b[33m\r\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "! pip show sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split train test data\n",
    "from sklearn.model_selection import train_test_split\n",
    "d_train, d_test, w_train, w_test = train_test_split(contexts_numpy, costs_numpy, test_size=100, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restricted license - for non-production use only - expires 2024-10-28\n"
     ]
    }
   ],
   "source": [
    "caps = capacities.cpu() #[20] * 2 # capacity\n",
    "optmodel = pyepo.model.grb.knapsackModel(weights_numpy, caps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizing for optDataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|█████████████████▊                    | 4631/9900 [00:19<00:20, 257.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Interrupt request received\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|████████████████████▎                 | 5280/9900 [00:21<00:17, 259.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Interrupt request received\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 9900/9900 [00:41<00:00, 241.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizing for optDataset...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 100/100 [00:00<00:00, 233.29it/s]\n"
     ]
    }
   ],
   "source": [
    "# get optDataset\n",
    "dataset_train = pyepo.data.dataset.optDataset(optmodel, d_train, w_train)\n",
    "dataset_test = pyepo.data.dataset.optDataset(optmodel, d_test, w_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set data loader\n",
    "from torch.utils.data import DataLoader\n",
    "batch_size = 256\n",
    "loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)\n",
    "loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Set-up\n",
    "Hyperparameters, initialize optimizer, initialize arrays etc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss_hist = []\n",
    "fmt = '[{:4d}/{:4d}]: train loss = {:7.3e} | test_loss = {:7.3e} ]'\n",
    "train_start_time = time.time()\n",
    "epoch = 0\n",
    "max_epochs = 10\n",
    "train_loss_ave = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Testing Regret loss function\n",
    "# Make eval mode that rounds before computing cost.\n",
    "class RegretLoss(nn.Module):\n",
    "    def __init__(self, num_item, num_constraints, device):\n",
    "        super(RegretLoss, self).__init__()\n",
    "        self.device = device\n",
    "        self.num_items = num_item  # number of items, i.e. non-dummy variables\n",
    "        self.num_constraints = num_constraints # number of constraints, i.e. dummy variables\n",
    "    \n",
    "    def forward(self, w_true, x_pred, x_opt, opt_value, eval_mode=False):\n",
    "        '''\n",
    "          d is (batch of) contexts, w_true is (batch of) true \n",
    "          cost vectors.\n",
    "        '''\n",
    "        if eval_mode:\n",
    "            regret = opt_value - (w_true*torch.round(x_pred[:,:-(self.num_constraints + self.num_items)])).sum(dim=-1)\n",
    "            regret = torch.mean(regret)/torch.mean(opt_value) #normalized regret\n",
    "        else:\n",
    "            regret = opt_value - (w_true*x_pred[:,:-(self.num_constraints + self.num_items)]).sum(dim=-1)\n",
    "            regret = torch.mean(regret)\n",
    "        return regret\n",
    "    \n",
    "    # cost = torch.matmul(w_true.view(-1, self.num_items), x_pred[:,:-self.num_constraints]) # leave out dummy variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(DYS_net.parameters(), lr=1e-3, weight_decay = 5e-4)\n",
    "criterion = RegretLoss(num_item, num_constraints, device=device)\n",
    "scheduler = ReduceLROnPlateau(optimizer, 'min')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:  0 test loss is  0.9049567580223083\n",
      "epoch:  1 test loss is  0.30895093083381653\n",
      "epoch:  2 test loss is  0.2976419925689697\n",
      "epoch:  3 test loss is  0.25240617990493774\n",
      "epoch:  4 test loss is  0.25336864590644836\n",
      "epoch:  5 test loss is  0.2569778561592102\n",
      "epoch:  6 test loss is  0.2613089680671692\n",
      "epoch:  7 test loss is  0.25240617990493774\n",
      "epoch:  8 test loss is  0.2649182081222534\n",
      "epoch:  9 test loss is  0.2666025161743164\n",
      "epoch:  10 test loss is  0.2930702567100525\n"
     ]
    }
   ],
   "source": [
    "epoch = 0\n",
    "while epoch <= max_epochs:\n",
    "    for d_batch, w_batch, opt_sol, opt_value in loader_train:\n",
    "        d_batch = d_batch.to(device)\n",
    "        w_batch = w_batch.to(device)\n",
    "        opt_sol = opt_sol.to(device)\n",
    "        opt_value = opt_value.to(device)\n",
    "        DYS_net.train()\n",
    "        optimizer.zero_grad()\n",
    "        predicted = DYS_net(d_batch)\n",
    "        loss = criterion(w_batch, predicted, opt_sol, opt_value)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item()\n",
    "        \n",
    "    DYS_net.eval()\n",
    "    test_loss = 0\n",
    "    for d_batch, w_batch, opt_sol, opt_value in loader_test:\n",
    "        d_batch = d_batch.to(device)\n",
    "        w_batch = w_batch.to(device)\n",
    "        opt_sol = opt_sol.to(device)\n",
    "        opt_value = opt_value.to(device)\n",
    "        predicted = DYS_net(d_batch)\n",
    "        test_loss += max(0., criterion(w_batch, predicted, opt_sol, opt_value, eval_mode=True).item())\n",
    "    \n",
    "    # print(predicted)\n",
    "    scheduler.step(test_loss)\n",
    "    test_loss_hist.append(test_loss)\n",
    "    print('epoch: ', epoch, 'test loss is ', test_loss)\n",
    "    epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([29., 29., 26., 26., 58., 24., 34., 24., 38., 45., 31., 14., 33., 20.,\n",
      "        18., 43., 25., 29., 22., 25., 27., 25., 26., 25., 28., 48., 40., 29.,\n",
      "        61., 23., 30., 33., 35., 41., 30., 35., 43., 31., 25., 30., 21., 39.,\n",
      "        29., 25., 20., 24., 30., 35., 24., 27., 28., 21., 37., 29., 22., 22.,\n",
      "        24., 34., 23., 30., 15., 28., 24., 27., 22., 30., 42., 23., 31., 19.,\n",
      "        29., 32., 33., 42., 26., 27., 32., 25., 35., 32., 22., 19., 29., 36.,\n",
      "        34., 15., 32., 40., 25., 30., 29., 27., 31., 19., 43., 27., 33., 26.,\n",
      "        22., 23.], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "print((w_batch*torch.round(predicted[:,:-(num_item + num_constraints)])).sum(dim=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[38.],\n",
      "        [42.],\n",
      "        [33.],\n",
      "        [40.],\n",
      "        [53.],\n",
      "        [39.],\n",
      "        [42.],\n",
      "        [32.],\n",
      "        [55.],\n",
      "        [50.],\n",
      "        [42.],\n",
      "        [33.],\n",
      "        [41.],\n",
      "        [40.],\n",
      "        [36.],\n",
      "        [50.],\n",
      "        [41.],\n",
      "        [46.],\n",
      "        [35.],\n",
      "        [37.],\n",
      "        [39.],\n",
      "        [40.],\n",
      "        [36.],\n",
      "        [37.],\n",
      "        [47.],\n",
      "        [59.],\n",
      "        [60.],\n",
      "        [42.],\n",
      "        [56.],\n",
      "        [32.],\n",
      "        [44.],\n",
      "        [45.],\n",
      "        [44.],\n",
      "        [50.],\n",
      "        [41.],\n",
      "        [53.],\n",
      "        [45.],\n",
      "        [47.],\n",
      "        [37.],\n",
      "        [42.],\n",
      "        [35.],\n",
      "        [45.],\n",
      "        [48.],\n",
      "        [36.],\n",
      "        [38.],\n",
      "        [40.],\n",
      "        [45.],\n",
      "        [50.],\n",
      "        [42.],\n",
      "        [38.],\n",
      "        [34.],\n",
      "        [36.],\n",
      "        [49.],\n",
      "        [42.],\n",
      "        [38.],\n",
      "        [34.],\n",
      "        [36.],\n",
      "        [43.],\n",
      "        [37.],\n",
      "        [45.],\n",
      "        [36.],\n",
      "        [45.],\n",
      "        [35.],\n",
      "        [35.],\n",
      "        [38.],\n",
      "        [40.],\n",
      "        [51.],\n",
      "        [38.],\n",
      "        [46.],\n",
      "        [34.],\n",
      "        [40.],\n",
      "        [41.],\n",
      "        [48.],\n",
      "        [46.],\n",
      "        [38.],\n",
      "        [39.],\n",
      "        [48.],\n",
      "        [40.],\n",
      "        [40.],\n",
      "        [44.],\n",
      "        [41.],\n",
      "        [37.],\n",
      "        [42.],\n",
      "        [44.],\n",
      "        [44.],\n",
      "        [32.],\n",
      "        [42.],\n",
      "        [44.],\n",
      "        [43.],\n",
      "        [44.],\n",
      "        [44.],\n",
      "        [38.],\n",
      "        [44.],\n",
      "        [35.],\n",
      "        [52.],\n",
      "        [38.],\n",
      "        [39.],\n",
      "        [35.],\n",
      "        [32.],\n",
      "        [32.]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "print(opt_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[19.3500, 19.3500, 19.3500, 19.3500, 24.2200, 19.3500, 19.3500, 19.3500,\n",
      "         14.4800, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 14.4800, 19.3500, 24.2200, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 24.2200, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         14.4800, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500, 19.3500,\n",
      "         19.3500, 19.3500, 19.3500, 19.3500],\n",
      "        [18.2600, 18.2600, 18.2600, 18.2600, 21.7300, 18.2600, 18.2600, 18.2600,\n",
      "         15.1000, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 15.1000, 18.2600, 21.7300, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 21.7300, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         15.1000, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600, 18.2600,\n",
      "         18.2600, 18.2600, 18.2600, 18.2600]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "print(torch.matmul(weights, torch.t(torch.round(predicted[:,:-(num_item + num_constraints)]))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[5., 7., 6.,  ..., 7., 4., 8.],\n",
      "        [4., 4., 8.,  ..., 5., 7., 8.],\n",
      "        [4., 6., 4.,  ..., 4., 7., 8.],\n",
      "        ...,\n",
      "        [7., 3., 6.,  ..., 5., 5., 4.],\n",
      "        [3., 5., 7.,  ..., 7., 6., 4.],\n",
      "        [6., 6., 6.,  ..., 4., 6., 7.]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "print(w_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4.5900, 4.8700, 5.1900, 7.5900, 6.8000, 4.9700, 4.8400, 3.2200, 7.9000,\n",
      "         3.6200, 3.1200, 6.8100, 7.0900, 3.2300, 4.4700, 5.2300, 6.1100, 5.2300,\n",
      "         4.2700, 6.6600, 5.2100, 3.7200, 7.8300, 5.6400, 7.5600, 3.2800, 4.3100,\n",
      "         3.8500, 6.4300, 4.2600, 7.5600, 6.4000, 6.0900, 6.4000, 4.5000, 6.3100,\n",
      "         6.3500, 3.0700, 6.0300, 7.3900, 3.4800, 5.5700, 6.2700, 4.5500, 5.6400,\n",
      "         5.6600, 3.2700, 6.6800, 4.2000, 5.6100, 6.1300, 7.8800, 4.6100, 4.9200,\n",
      "         6.0900, 7.9500, 7.1100, 4.0200, 3.0700, 7.8400, 6.2900, 5.7800, 5.5700,\n",
      "         5.6500, 7.4100, 5.2100, 3.4200, 6.4800, 5.7700, 7.6700, 7.3200, 3.3100,\n",
      "         4.3600, 4.7500, 3.3800, 3.4400, 6.3000, 4.8500, 7.0900, 5.7000, 4.0500,\n",
      "         6.0400, 6.5100, 3.4100, 7.9400, 6.0500, 7.1800, 3.6400, 4.3500, 4.4500,\n",
      "         4.2700, 4.8700, 5.7400, 3.1600, 5.4100, 3.8100, 3.5100, 5.9800, 5.7900,\n",
      "         6.6200],\n",
      "        [5.1100, 3.1600, 7.0100, 6.4800, 6.3600, 4.1200, 7.9000, 7.6900, 3.6800,\n",
      "         4.7300, 5.8500, 5.0300, 3.6900, 7.1300, 7.1100, 5.7200, 5.8000, 4.5900,\n",
      "         6.8500, 4.1400, 6.1800, 4.2800, 5.3800, 5.7900, 6.7800, 3.7900, 4.6100,\n",
      "         3.0600, 7.7800, 3.9400, 6.0600, 7.2200, 5.1800, 3.4400, 3.9300, 5.5600,\n",
      "         3.5600, 6.5900, 5.0000, 5.2400, 7.9400, 6.7500, 6.3700, 7.6400, 3.4000,\n",
      "         3.8700, 3.3400, 4.9500, 3.5300, 4.6400, 7.2800, 6.9000, 5.1600, 4.4900,\n",
      "         5.0400, 7.8200, 7.4100, 6.9200, 7.4500, 7.2900, 6.3100, 6.6000, 4.6600,\n",
      "         4.0700, 7.2500, 5.5800, 5.3300, 7.8200, 7.3700, 6.2300, 4.1200, 7.2400,\n",
      "         6.1000, 6.2600, 6.0900, 4.8700, 6.1200, 7.2000, 7.4800, 4.8700, 6.6600,\n",
      "         5.1000, 6.1700, 4.6900, 4.4100, 4.0400, 6.0000, 3.7100, 6.8200, 5.1500,\n",
      "         6.9500, 3.4700, 4.1600, 5.1700, 7.9700, 6.0800, 3.8700, 7.0300, 3.8800,\n",
      "         3.1300]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "print(weights) # S matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "# print(torch.round(predicted[:,:-12]) - torch.round(predicted[0,:-12]))\n",
    "print(torch.round((predicted[:,50:60])))#:-(num_constraints + num_item)])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],\n",
      "        [ 0., -1.,  0.,  ...,  0.,  0.,  0.],\n",
      "        [ 0., -1.,  0.,  ...,  0.,  0.,  0.],\n",
      "        ...,\n",
      "        [ 0., -1.,  0.,  ...,  0.,  0.,  0.],\n",
      "        [ 0., -1.,  0.,  ...,  0.,  0.,  0.],\n",
      "        [ 0., -1.,  0.,  ...,  0.,  0.,  0.]], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "print(opt_sol[:15,:] - opt_sol[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x2c13ea950>]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvAUlEQVR4nO3dfVRU94H/8c/MwAzPoKCgiA4aE2MewEKlxjx1S8s2Obam53TtNo0eurFns3rWhN/uNiZRdm0j2/bEn9usjYkn7vakm8Zta5p2k9q6bJOsv9iQguahsRqrCBp58oHRQQFn7u8PmIFRngZm5g4z79c59wCXe2e+cBLm4/d+P3cshmEYAgAAMInV7AEAAID4RhgBAACmIowAAABTEUYAAICpCCMAAMBUhBEAAGAqwggAADAVYQQAAJgqwewBjIXX69XHH3+s9PR0WSwWs4cDAADGwDAMXbhwQTNnzpTVOvz8x6QIIx9//LEKCgrMHgYAABiH5uZmzZo1a9jvT4owkp6eLqnvh8nIyDB5NAAAYCxcLpcKCgr8r+PDmRRhxHdpJiMjgzACAMAkM9oSCxawAgAAUxFGAACAqQgjAADAVIQRAABgKsIIAAAwFWEEAACYijACAABMRRgBAACmIowAAABTEUYAAICpCCMAAMBUhBEAAGCquA4jP3yrUd/86Xtq7HCbPRQAAOJWXIeR3QdOadfvm3XotMvsoQAAELfGFUa2bdsmp9OppKQklZWVqa6ubthje3t7tWnTJs2bN09JSUkqKirSnj17xj3gUCrMTpEkNZ7pMnkkAADEr6DDyK5du1RVVaXq6mo1NDSoqKhIFRUVamtrG/L4J554Qs8++6yefvppffjhh/rrv/5r3XfffTpw4MCEBz9Rc7JTJYnLNAAAmCjoMLJlyxatXr1alZWVWrhwobZv366UlBTt3LlzyONfeOEFPfbYY7rnnns0d+5cPfTQQ7rnnnv01FNPTXjwE1WY0x9GzhBGAAAwS1BhpKenR/X19SovLx94AKtV5eXl2r9//5DndHd3KykpKWBfcnKy9u3bN+zzdHd3y+VyBWzh4CSMAABguqDCSEdHhzwej3JzcwP25+bmqqWlZchzKioqtGXLFn300Ufyer3au3evdu/erdOnTw/7PDU1NcrMzPRvBQUFwQxzzJz9a0ZaXd3q6rkSlucAAAAjC3ub5l/+5V80f/58LViwQHa7XWvXrlVlZaWs1uGfev369ers7PRvzc3NYRlbVopdWSmJkqQTLGIFAMAUQYWRnJwc2Ww2tba2BuxvbW1VXl7ekOdMmzZNP//5z+V2u3XixAn98Y9/VFpamubOnTvs8zgcDmVkZARs4eJbxHqCSzUAAJgiqDBit9tVUlKi2tpa/z6v16va2lotWbJkxHOTkpKUn5+vK1eu6Gc/+5m++MUvjm/EIear9x7vYGYEAAAzJAR7QlVVlVatWqXS0lItXrxYW7duldvtVmVlpSRp5cqVys/PV01NjSTp7bff1qlTp1RcXKxTp07pH//xH+X1evUP//APof1Jxol6LwAA5go6jKxYsULt7e3auHGjWlpaVFxcrD179vgXtTY1NQWsB7l8+bKeeOIJHTt2TGlpabrnnnv0wgsvKCsrK2Q/xERQ7wUAwFwWwzAMswcxGpfLpczMTHV2doZ8/ciBpnO67wdvKTfDobcfKx/9BAAAMCZjff2O6/emkQZmRqj3AgBgjrgPI9R7AQAwV9yHEYl6LwAAZiKMiHovAABmIoyImREAAMxEGNHAItbj3GsEAICII4xImtN/mYZ7jQAAEHmEEVHvBQDATIQR9dV7M5Op9wIAYAbCSD9nDotYAQAwA2GkH/VeAADMQRjpR70XAABzEEb6Ue8FAMAchJF+vnovC1gBAIgswkg/38xIi+uyLvV4TB4NAADxgzDSb3C9l5ufAQAQOYSRQaj3AgAQeYSRQZzUewEAiDjCyCBO6r0A
AEQcYWQQ6r0AAEQeYWQQ6r0AAEQeYWQQ6r0AAEQeYWSQgHfvPculGgAAIoEwchVfvbeRdSMAAEQEYeQq1HsBAIgswshVqPcCABBZhJGrOHN8MyOEEQAAIoEwcpWBmREu0wAAEAmEkatQ7wUAILIII1eh3gsAQGQRRoZAvRcAgMghjAzBV+9tZN0IAABhRxgZgm8RKzMjAACEH2FkCNR7AQCIHMLIEKj3AgAQOYSRIfjCCPVeAADCjzAyhCmp1HsBAIgUwsgwqPcCABAZhJFhUO8FACAyCCPDoN4LAEBkjCuMbNu2TU6nU0lJSSorK1NdXd2Ix2/dulU33HCDkpOTVVBQoEceeUSXL18e14AjxVfvbTxDGAEAIJyCDiO7du1SVVWVqqur1dDQoKKiIlVUVKitrW3I41988UU9+uijqq6u1qFDh/T8889r165deuyxxyY8+HAamBnhMg0AAOEUdBjZsmWLVq9ercrKSi1cuFDbt29XSkqKdu7cOeTxb731lpYuXaqvfvWrcjqd+tznPqe//Mu/HHU2xWzUewEAiIygwkhPT4/q6+tVXl4+8ABWq8rLy7V///4hz7nttttUX1/vDx/Hjh3Ta6+9pnvuuWfY5+nu7pbL5QrYIo16LwAAkRFUGOno6JDH41Fubm7A/tzcXLW0tAx5zle/+lVt2rRJt99+uxITEzVv3jzdfffdI16mqampUWZmpn8rKCgIZpgh42/UsIgVAICwCXub5vXXX9fmzZv1gx/8QA0NDdq9e7deffVVfetb3xr2nPXr16uzs9O/NTc3h3uYQ/Lfa4R6LwAAYZMQzME5OTmy2WxqbW0N2N/a2qq8vLwhz9mwYYMeeOABPfjgg5KkW265RW63W9/4xjf0+OOPy2q9Ng85HA45HI5ghhYWc6j3AgAQdkHNjNjtdpWUlKi2tta/z+v1qra2VkuWLBnynK6urmsCh81mkyQZhhHseCOqkHovAABhF9TMiCRVVVVp1apVKi0t1eLFi7V161a53W5VVlZKklauXKn8/HzV1NRIkpYtW6YtW7Zo0aJFKisr09GjR7VhwwYtW7bMH0qiFfVeAADCL+gwsmLFCrW3t2vjxo1qaWlRcXGx9uzZ41/U2tTUFDAT8sQTT8hiseiJJ57QqVOnNG3aNC1btkxPPvlk6H6KMLm63ptsj+7wBADAZGQxov1aiSSXy6XMzEx1dnYqIyMjos9d9E+/UeelXu15+A4tyIvscwMAMJmN9fWb96YZBfVeAADCizAyCuq9AACEF2FkFNR7AQAIL8LIKKj3AgAQXoSRUcyh3gsAQFgRRkZRyLv3AgAQVoSRUfDuvQAAhBdhZAwG6r1cqgEAINQII2MwUO9lZgQAgFAjjIwB9V4AAMKHMDIG1HsBAAgfwsgYUO8FACB8CCNjQL0XAIDwIYyMQVZKojKSEiRR7wUAINQII2NgsVhUmMOlGgAAwoEwMkbUewEACA/CyBj5FrGeIIwAABBShJEx8tV7j3OvEQAAQoowMkbUewEACA/CyBhR7wUAIDwII2NEvRcAgPAgjIwR9V4AAMKDMBIE/7oRGjUAAIQMYSQIvnuNUO8FACB0CCNBoN4LAEDoEUaCMHDjM9aMAAAQKoSRIPjqvac7qfcCABAqhJEgDK73Np1ldgQAgFAgjARhcL2XdSMAAIQGYSRI1HsBAAgtwkiQqPcCABBahJEgObOp9wIAEEqEkSANzIywgBUAgFAgjASJei8AAKFFGAkS9V4AAEKLMBIk6r0AAIQWYWQcBm4LTxgBAGCiCCPj4FvEyr1GAACYOMLIOFDvBQAgdAgj40C9FwCA0BlXGNm2bZucTqeSkpJUVlamurq6YY+9++67ZbFYrtnuvffecQ/abE7qvQAAhEzQYWTXrl2qqqpSdXW1GhoaVFRUpIqKCrW1tQ15/O7du3X69Gn/9sEHH8hms+nLX/7yhAdvlinUewEACJmgw8iWLVu0evVqVVZWauHChdq+fbtS
UlK0c+fOIY+fOnWq8vLy/NvevXuVkpIyqcMI9V4AAEInqDDS09Oj+vp6lZeXDzyA1ary8nLt379/TI/x/PPP6ytf+YpSU1OHPaa7u1sulytgizbUewEACI2gwkhHR4c8Ho9yc3MD9ufm5qqlpWXU8+vq6vTBBx/owQcfHPG4mpoaZWZm+reCgoJghhkR1HsBAAiNiLZpnn/+ed1yyy1avHjxiMetX79enZ2d/q25uTlCIxw7X723sYM1IwAATERCMAfn5OTIZrOptbU1YH9ra6vy8vJGPNftduull17Spk2bRn0eh8Mhh8MRzNAijpkRAABCI6iZEbvdrpKSEtXW1vr3eb1e1dbWasmSJSOe+5Of/ETd3d362te+Nr6RRhnqvQAAhEbQl2mqqqq0Y8cO/fCHP9ShQ4f00EMPye12q7KyUpK0cuVKrV+//prznn/+eS1fvlzZ2dkTH3UUoN4LAEBoBHWZRpJWrFih9vZ2bdy4US0tLSouLtaePXv8i1qbmppktQZmnMOHD2vfvn36zW9+E5pRRwGLxSJnTqreO9mp4x1u3ZCXbvaQAACYlIIOI5K0du1arV27dsjvvf7669fsu+GGG2QYxnieKqo5s/vCCPVeAADGj/emmQAWsQIAMHGEkQmg3gsAwMQRRiaAmREAACaOMDIBg+u9l3up9wIAMB6EkQkYXO89cYZLNQAAjAdhZAJ89V6Jd+8FAGC8CCMT5OTdewEAmBDCyAT5GzWEEQAAxoUwMkH+Rg31XgAAxoUwMkHUewEAmBjCyARR7wUAYGIIIxNEvRcAgIkhjEzQ4Hovl2oAAAgeYSQEfJdqGrnXCAAAQSOMhAD1XgAAxo8wEgLUewEAGD/CSAjMyWbNCAAA40UYCYHCHOq9AACMF2EkBKj3AgAwfoSREKDeCwDA+BFGQoR6LwAA40MYCZGBei+XaQAACAZhJEQG6r3MjAAAEAzCSIhQ7wUAYHwIIyFCvRcAgPEhjITIlJREpVPvBQAgaISRELFYLP7ZES7VAAAwdoSREKLeCwBA8AgjIUS9FwCA4BFGQoh6LwAAwSOMhJCv3nuCNSMAAIwZYSSEfAtYP6beCwDAmBFGQoh6LwAAwSOMhBD1XgAAgkcYCbE51HsBAAgKYSTECqn3AgAQFMJIiDEzAgBAcAgjIea71wj1XgAAxoYwEmLUewEACA5hJMQG13ubzrJuBACA0YwrjGzbtk1Op1NJSUkqKytTXV3diMefP39ea9as0YwZM+RwOHT99dfrtddeG9eAo93geu9x1o0AADCqoMPIrl27VFVVperqajU0NKioqEgVFRVqa2sb8vienh599rOfVWNjo37605/q8OHD2rFjh/Lz8yc8+GjFIlYAAMYuIdgTtmzZotWrV6uyslKStH37dr366qvauXOnHn300WuO37lzp86ePau33npLiYmJkiSn0zmxUUc56r0AAIxdUDMjPT09qq+vV3l5+cADWK0qLy/X/v37hzznF7/4hZYsWaI1a9YoNzdXN998szZv3iyPZ/jFnd3d3XK5XAHbZMLMCAAAYxdUGOno6JDH41Fubm7A/tzcXLW0tAx5zrFjx/TTn/5UHo9Hr732mjZs2KCnnnpK3/72t4d9npqaGmVmZvq3goKCYIZpOuq9AACMXdjbNF6vV9OnT9dzzz2nkpISrVixQo8//ri2b98+7Dnr169XZ2enf2tubg73MEPK2X+ZhnovAACjC2rNSE5Ojmw2m1pbWwP2t7a2Ki8vb8hzZsyYocTERNlsNv++G2+8US0tLerp6ZHdbr/mHIfDIYfDEczQosrUVLvSkxJ04fIVNZ3t0vW56WYPCQCAqBXUzIjdbldJSYlqa2v9+7xer2pra7VkyZIhz1m6dKmOHj0qr9fr33fkyBHNmDFjyCASC6j3AgAwdkFfpqmqqtKOHTv0wx/+UIcOHdJDDz0kt9vtb9esXLlS69ev9x//0EMP6ezZs1q3bp2OHDmiV199VZs3
b9aaNWtC91NEId8iVtaNAAAwsqCrvStWrFB7e7s2btyolpYWFRcXa8+ePf5FrU1NTbJaBzJOQUGBfv3rX+uRRx7Rrbfeqvz8fK1bt07f/OY3Q/dTRCFfvfd4B/VeAABGYjEMwzB7EKNxuVzKzMxUZ2enMjIyzB7OmPys/qT+z0/e1ZK52frxNz5l9nAAAIi4sb5+8940YUK9FwCAsSGMhAn1XgAAxoYwEia+eq/Eu/cCADASwkiYWCwWObOp9wIAMBrCSBixbgQAgNERRsKIei8AAKMjjIQRNz4DAGB0hJEw8l2maWTNCAAAwyKMhBH1XgAARkcYCSPqvQAAjI4wEkbUewEAGB1hJMyo9wIAMDLCSJg5qfcCADAiwkiYOan3AgAwIsJImFHvBQBgZISRMKPeCwDAyAgjYUa9FwCAkRFGwmxwvZdLNQAAXIswEgH+dSMsYgUA4BqEkQig3gsAwPAIIxFAvRcAgOERRiLAmdM3M8KaEQAArkUYiQDfzAj1XgAArkUYiQDqvQAADI8wEgHUewEAGB5hJEKo9wIAMDTCSIT46r2NZ7hMAwDAYISRCOEyDQAAQyOMRAj1XgAAhkYYiRDqvQAADI0wEiFTU+1Kd1DvBQDgaoSRCLFYLAONGi7VAADgRxiJIOq9AABcizASQdR7AQC4FmEkgqj3AgBwLcJIBPnqvSeYGQEAwI8wEkED9d5L1HsBAOhHGIkgX73XMKj3AgDgQxiJIOq9AABcizASYXP8jRrCCAAA0jjDyLZt2+R0OpWUlKSysjLV1dUNe+y///u/y2KxBGxJSUnjHvBkV+i/1wiXaQAAkMYRRnbt2qWqqipVV1eroaFBRUVFqqioUFtb27DnZGRk6PTp0/7txIkTExr0ZEa9FwCAQEGHkS1btmj16tWqrKzUwoULtX37dqWkpGjnzp3DnmOxWJSXl+ffcnNzJzToyYx6LwAAgYIKIz09Paqvr1d5efnAA1itKi8v1/79+4c97+LFi5ozZ44KCgr0xS9+UX/4wx/GP+JJjnovAACBggojHR0d8ng818xs5ObmqqWlZchzbrjhBu3cuVOvvPKKfvSjH8nr9eq2227TyZMnh32e7u5uuVyugC1WDK73NlPvBQAg/G2aJUuWaOXKlSouLtZdd92l3bt3a9q0aXr22WeHPaempkaZmZn+raCgINzDjJjB9d7jrBsBACC4MJKTkyObzabW1taA/a2trcrLyxvTYyQmJmrRokU6evTosMesX79enZ2d/q25uTmYYUY96r0AAAwIKozY7XaVlJSotrbWv8/r9aq2tlZLliwZ02N4PB69//77mjFjxrDHOBwOZWRkBGyxhHovAAADEoI9oaqqSqtWrVJpaakWL16srVu3yu12q7KyUpK0cuVK5efnq6amRpK0adMmfepTn9J1112n8+fP63vf+55OnDihBx98MLQ/ySQyh3ovAAB+QYeRFStWqL29XRs3blRLS4uKi4u1Z88e/6LWpqYmWa0DEy7nzp3T6tWr1dLSoilTpqikpERvvfWWFi5cGLqfYpIppN4LAICfxTAMw+xBjMblcikzM1OdnZ0xccnmzMVulXz7v2WxSIc2/bmSEm1mDwkAgJAb6+s3701jAuq9AAAMIIyYgHovAAADCCMm8dV7WTcCAIh3hBGT+Oq9x7nXCAAgzhFGTEK9FwCAPoQRk1DvBQCgD2HEJHN4914AACQRRkyTTb0XAABJhBHTUO8FAKAPYcRE1HsBACCMmIp6LwAAhBFT+RaxniCMAADiGGHERL56b2MHl2kAAPGLMGIi6r0AABBGTEW9FwAAwoipLBaL5vRfqqHeCwCIV4QRkzn9i1iZGQEAxCfCiMmo9wIA4h1hxGTUewEA8Y4wYjLqvQCAeEcYMRn1XgBAvCOMmIx6LwAg3hFGTEa9FwAQ7wgjUYB6LwAgnhFGooAvjFDvBQDEI8JIFHDmUO8FAMQvwkgUcGZT7wUAxC/CSBTwzYxQ7wUA
xCPCSBSg3gsAiGeEkSgwuN7bSKMGABBnCCNRwteoaeReIwCAOEMYiRLUewEA8YowEiWo9wIA4hVhJEpQ7wUAxCvCSJSg3gsAiFeEkSiRnWpXGvVeAEAcIoxECYvFIif1XgBAHCKMRBHqvQCAeEQYiSL+MEKjBgAQRwgjUcS3iJUwAgCIJ4SRKEK9FwAQj8YVRrZt2yan06mkpCSVlZWprq5uTOe99NJLslgsWr58+XieNuZR7wUAxKOgw8iuXbtUVVWl6upqNTQ0qKioSBUVFWpraxvxvMbGRv3d3/2d7rjjjnEPNtZR7wUAxKOgw8iWLVu0evVqVVZWauHChdq+fbtSUlK0c+fOYc/xeDy6//779U//9E+aO3fuhAYcy6j3AgDiUVBhpKenR/X19SovLx94AKtV5eXl2r9//7Dnbdq0SdOnT9df/dVfjel5uru75XK5ArZ4MYd6LwAgzgQVRjo6OuTxeJSbmxuwPzc3Vy0tLUOes2/fPj3//PPasWPHmJ+npqZGmZmZ/q2goCCYYU5qhdR7AQBxJqxtmgsXLuiBBx7Qjh07lJOTM+bz1q9fr87OTv/W3NwcxlFGF+q9AIB4kxDMwTk5ObLZbGptbQ3Y39raqry8vGuO/9Of/qTGxkYtW7bMv8/r9fY9cUKCDh8+rHnz5l1znsPhkMPhCGZoMYN6LwAg3gQ1M2K321VSUqLa2lr/Pq/Xq9raWi1ZsuSa4xcsWKD3339fBw8e9G9f+MIX9OlPf1oHDx6Mq8svY0W9FwAQb4KaGZGkqqoqrVq1SqWlpVq8eLG2bt0qt9utyspKSdLKlSuVn5+vmpoaJSUl6eabbw44PysrS5Ku2Y8+vnrvxe4rOnmuS9dNTzd7SAAAhFXQYWTFihVqb2/Xxo0b1dLSouLiYu3Zs8e/qLWpqUlWKzd2HS9fvfeDUy4d7yCMAABin8UwDMPsQYzG5XIpMzNTnZ2dysjIMHs4YbfmxQa9+t5pPX7PjVp9J/dlAQBMTmN9/WYKIwpR7wUAxBPCSBSa42vUEEYAAHGAMBKFCn33GqHeCwCIA4SRKES9FwAQTwgjUWjwu/eePMfsCAAgthFGotDgd+89zqUaAECMI4xEKd+7955gESsAIMYRRqKUr957vIMwAgCIbYSRKEW9FwAQLwgjUYp6LwAgXhBGopRvzQj1XgBArCOMRKmcNOq9AID4QBiJUtR7AQDxgjASxaj3AgDiAWEkilHvBQDEA8JIFPPVe0+c4TINACB2EUaimK/ey8wIACCWEUaiGPVeAEA8IIxEMeq9AIB4QBiJYhaLxb9uhHovACBWEUainDOHei8AILYRRqIc9V4AQKwjjEQ56r0AgFhHGIly1HsBALGOMBLlBtd7u69Q7wUAxB7CSJQbXO9tPsulGgBA7CGMRDnqvQCAWEcYmQSo9wIAYhlhZBJw+mdGCCMAgNhDGJkEnNm+mREu0wAAYg9hZBKg3gsAiGWEkUmAei8AIJYRRiYB6r0AgFhGGJkEBtd7G6n3AgBiDGFkkvDVexup9wIAYgxhZJKg3gsAiFWEkUmCei8AIFYRRiYJJ/VeAECMIoxMEk7qvQCAGEUYmSSo9wIAYtW4wsi2bdvkdDqVlJSksrIy1dXVDXvs7t27VVpaqqysLKWmpqq4uFgvvPDCuAccr6j3AgBiVdBhZNeuXaqqqlJ1dbUaGhpUVFSkiooKtbW1DXn81KlT9fjjj2v//v167733VFlZqcrKSv3617+e8ODjDfVeAEAsCjqMbNmyRatXr1ZlZaUWLlyo7du3KyUlRTt37hzy+Lvvvlv33XefbrzxRs2bN0/r1q3Trbfeqn379k148PHGV+8ljAAAYklQYaSnp0f19fUqLy8feACrVeXl5dq/f/+o5xuGodraWh0+fFh33nnnsMd1d3fL5XIFbBhYxMplGgBALAkqjHR0dMjj8Sg3Nzdgf25urlpaWoY9r7OzU2lpabLb7br33nv19NNP67Of/eyw
x9fU1CgzM9O/FRQUBDPMmEW9FwAQiyLSpklPT9fBgwf1zjvv6Mknn1RVVZVef/31YY9fv369Ojs7/Vtzc3Mkhhn1qPcCAGJRQjAH5+TkyGazqbW1NWB/a2ur8vLyhj3ParXquuuukyQVFxfr0KFDqqmp0d133z3k8Q6HQw6HI5ihxYWcNLtS7Ta5ezxqPtul66anmz0kAAAmLKgwYrfbVVJSotraWi1fvlyS5PV6VVtbq7Vr1475cbxer7q7u4MaKPrqvc6cVP3hY5ce2/2B5k1PVXaqQ1NT7cpOsysnbeDzqSl2Jdi4jQwAIPoFFUYkqaqqSqtWrVJpaakWL16srVu3yu12q7KyUpK0cuVK5efnq6amRlLf+o/S0lLNmzdP3d3deu211/TCCy/omWeeCe1PEidunZWlP3zsUl3jWdU1nh3x2Ckpif3hxKGcNHvf56m+zx39Aabv86zkRFmtlgj9FAAADAg6jKxYsULt7e3auHGjWlpaVFxcrD179vgXtTY1NclqHfgXudvt1t/8zd/o5MmTSk5O1oIFC/SjH/1IK1asCN1PEUeqly3Up2+YpvaL3TpzsUdnLnbrjLun73N3375zXT3yGtK5rl6d6+rVn9pHX/Bqs1o0JcWu7P6Zlew0R9/n/WEmOy3w83RHgiwWwgsAYOIshmEYZg9iNC6XS5mZmers7FRGRobZw4l6Hq+h810914SUvq/7Pj/r7lFH//7OS71BP0eizaLs/tmVqal9l4iyU+2ammZXztX70+xKsQedewHEGMMw5PEautK/eTyGrni98ngN9V719RWvoSseQ+lJCZo1JZnLzpPUWF+/eYWIQTarpX8GwyHljn58r8erc+4edfQHl7O+zy8O+rx//5mLPbrYfUW9HkMtrstqcV0e05jSHQn6sxun675F+br9uhz+sAARdOr8Jf2/jzrUfrFbV/pf8K94+4OBx5DH6x0UBvq+9n2/1xP4dV9IGAgMg/cN/rrv3MDjPN7x/ds3wWrR7KkpKsxJ7dumpaowu+9jXkYSs7QxgJkRBO1yr8cfTHyzK2f7P3b4Pvd9/2K3uq94A86flu7QF4pm6r5F+bppZgZ/SIAQu9zr0e+OndGbRzr05kftOtp20ewhjchqkRJsViVYLbJZLf0f+74+f6lHl3u9w56bnGiTMydVc3NS5cxJUWFOmgr7v56Sao/gT4GhjPX1mzCCsDIMQ109Hh1uvaBXDpzSL987rbPuHv/3r89N032LZmn5opmakZls4kiBycswDB1tu6g3jrTrjSPtqjt+NuAfAVaLtGj2FM2blqoEm1WJvhd728CLf8JVX9usliECgkWJNmvgMVZr//7ArxNsA49rsw06/6qvbRbLiIvnvd6+WdjjHW4d63CrscOt4/1b09muEWdbslIS5czuCyb+GZWcVDmzU5Xq4MJAJBBGEJV6PV69cbhdLx84pb2HWtXT/wfTYpGWzM3Wlz4xS39+c57S+EMBjKjzUq/+39EOvdkfQE53Bl4ynZGZpLuun6Y7r5+mpfNylJmSaNJIw6fX41Xz2S41nnHrWPtASGnscOvjzpEvIedmOPov+6SpcNCMyuypKbIncBk5VAgjiHqdl3r1q/dPa/eBU6o7PlBTTkq0quKmPNaXAIN4vIbeO3nef+nlQNM5DZ4UsCdYVVY4VXddP013XT9N101Pi+tLoJd6PGo8MxBQBm+DZ2evZrVIBVNT5Mzum0WZ2z+bUpiTqpmZydwCIUiEEUwqzWe79PMDp/TygVM6Nui9d3LSHPpiMetLEJ9aXZf9Mx/7jnbofFdg8+266Wm6c/403Xl9jsoKs5Vst5k00snlfFdP3wzKGbeOt/dd/vHNqLh7hn+rDXuCVc7slCFnVHLS7Px9GgJhBJOSYRh692SnXm44yfoSxJ3uKx79vvGcP4D8seVCwPfTkxJ0+3U5urP/8kt+Fv8fhJJhGGq/0B0QTnyfnzjjVq9n+JfLdEeCCqf1rUeZPz1N
RQVZKirIUmZy7F0eCwZhBJPeaOtL7luUr8/fMoP1JZi0DMNQ45kuvXG4TW9+1KH9fzqjS70D/zK3WKRb8zN1Z/+ll+KCLC5bmsTjNXTq3CUdP+PW8faLAwtqz7h18twlDfdKOndaqopnZal4dpaKZmXpxhkZcbUmhTCCmDLS+pLPLczTfZ/I1x2sL8EkcLH7it462qE3jrTrzY/a1Xz2UsD3p6U7/Jde7pg/TVOpp0a9y719b17qm0X58GOX3j15XifOdF1zrN1m1cKZGSouyPJvc7JTYvYSD2EEMYv1JZhMvF5DH552+Wu3DSfO6cqglaeJNotK50zVXTdM053zp+nGGen8txsjzrp79O7J8zrYdL7vY/P5a9b9SH0V5KJZA+GkqCArZkIoYQQxj/UliFYdF7v1vx+1680jHfrfj9rVcTGwveHMTvFfevnU3GzueREnDMPQiTNdevfkeR3oDyh/+NjlvwQ92OypKQHh5KaZGUpKnHwLlAkjiCusL4GZej1eNZw457/08sEpV8D3U+w23TYvR3dd37f4dE52qkkjRbTpueLVodMu/wzKwZPndWyINzdNsFp044wMfzgpLsjS3JzUqK8aE0YQt1hfMjzfG5V5jIH3CgnYjL73KvEafe8n4r3qfUWGPa//uKvP859vGPJ4vPIYksfrlccb+NFi6b8b56C7fVotfXfxtFr6vx5010/fnTsHnzN4/8B5Vlmt6r8rqGSzWvvOsw19vv95rSPfFVTqu1zou/Sy/09ndLH7SsD3F87I8F96KZkzJa4WLWJiOrt69d6p/nDS3LedGeLeKOlJCf7LO76AMi3dYcKIh0cYATTy+pIvFM3Ulz4R3etLPF5DnZd6db6rR+e6rv446HN3r85f6lVnV496fGHC45XXkK54vfJ6+z9G/f/t0SUgBA0KMZKueXGYmmrXHfNzdOf8abrj+hxNT08yY8iIQYZh6OS5SwOzJ83n9f6pzmve90uS8rOSAy7v3JKfaer9ZwgjwCAjrS+ZPz1NX/pEeNeXGIahS70enevq1Tl3j8539epcV8+gcDFEwOjqlety77CVwVCzWhQwk+B78zLfTMFQMxCjz0wMvJgHvKhfdY5hKHBGxRh499ghZ1/GM3tj9L0rreeq88bLZrWoZPYU3dl/6eXmmZlRP2WO2NHr8epwy4WAgHK0/eI1fy9sVouuz03vDyiZKi6Youump8kWof9WCSPAMCa6vuSKx6vz/bMVfaHi6mDRN1Nxrv/75y/17R9qkdpYpTsSlJWaqCkpdmWl2JWVnKgpKYnKSrFrSkqipqQO7LcnWANDQH84GCk82KyWqJ0dCrdrA9DwIcZ3aemK16uCqSnKSIrvG1ohuly43Kv3T3bq4KCA0nah+5rjUu023TKrL5j4AkpeZnhm8ggjwBiMtL7kzxZMl91mDZipONfVowuXr4zwiCNLtFn8AcIfJPoDhu/zzP6PvmOyUhKVGIfrWwBM3OnOS/6FsQeb+i7vdA1xy/vcDIf+74pi3TYvJ6TPTxgBgtR8tkuvHDyl3QdODbma/WoZSQn+GYkpKYnKSvYFDLumpF4dNvo+pthtcTsDAcB8Hq+hj9ou+O99cqDpvI60XpDXkP676i5dNz0tpM9HGAHGyTAMvXeyU//7UbscCTZ/kMgaNJuRmZwYl20cALGnq+eK3j/ZqU86p4Z83dNYX7+56QJwFYvF4n+TKwCIdSn2BJXNzTZ1DPzTDgAAmIowAgAATEUYAQAApiKMAAAAUxFGAACAqQgjAADAVIQRAABgKsIIAAAwFWEEAACYijACAABMRRgBAACmIowAAABTEUYAAICpJsW79hqGIanvrYgBAMDk4Hvd9r2OD2dShJELFy5IkgoKCkweCQAACNaFCxeUmZk57PctxmhxJQp4vV59/PHHSk9Pl8ViCdnjulwuFRQUqLm5WRkZGSF7XATi9xw5/K4jg99zZPB7joxw/p4Nw9CFCxc0c+ZMWa3DrwyZ
FDMjVqtVs2bNCtvjZ2Rk8B96BPB7jhx+15HB7zky+D1HRrh+zyPNiPiwgBUAAJiKMAIAAEwV12HE4XCourpaDofD7KHENH7PkcPvOjL4PUcGv+fIiIbf86RYwAoAAGJXXM+MAAAA8xFGAACAqQgjAADAVIQRAABgqrgOI9u2bZPT6VRSUpLKyspUV1dn9pBiSk1NjT75yU8qPT1d06dP1/Lly3X48GGzhxXz/vmf/1kWi0UPP/yw2UOJOadOndLXvvY1ZWdnKzk5Wbfccot+//vfmz2smOPxeLRhwwYVFhYqOTlZ8+bN07e+9a1R398EI3vzzTe1bNkyzZw5UxaLRT//+c8Dvm8YhjZu3KgZM2YoOTlZ5eXl+uijjyIytrgNI7t27VJVVZWqq6vV0NCgoqIiVVRUqK2tzeyhxYw33nhDa9as0e9+9zvt3btXvb29+tznPie322320GLWO++8o2effVa33nqr2UOJOefOndPSpUuVmJioX/3qV/rwww/11FNPacqUKWYPLeZ85zvf0TPPPKN//dd/1aFDh/Sd73xH3/3ud/X000+bPbRJze12q6ioSNu2bRvy+9/97nf1/e9/X9u3b9fbb7+t1NRUVVRU6PLly+EfnBGnFi9ebKxZs8b/tcfjMWbOnGnU1NSYOKrY1tbWZkgy3njjDbOHEpMuXLhgzJ8/39i7d69x1113GevWrTN7SDHlm9/8pnH77bebPYy4cO+99xpf//rXA/Z96UtfMu6//36TRhR7JBkvv/yy/2uv12vk5eUZ3/ve9/z7zp8/bzgcDuPHP/5x2McTlzMjPT09qq+vV3l5uX+f1WpVeXm59u/fb+LIYltnZ6ckaerUqSaPJDatWbNG9957b8B/1widX/ziFyotLdWXv/xlTZ8+XYsWLdKOHTvMHlZMuu2221RbW6sjR45Ikt59913t27dPn//8500eWew6fvy4WlpaAv5+ZGZmqqysLCKvi5PijfJCraOjQx6PR7m5uQH7c3Nz9cc//tGkUcU2r9erhx9+WEuXLtXNN99s9nBizksvvaSGhga98847Zg8lZh07dkzPPPOMqqqq9Nhjj+mdd97R3/7t38put2vVqlVmDy+mPProo3K5XFqwYIFsNps8Ho+efPJJ3X///WYPLWa1tLRI0pCvi77vhVNchhFE3po1a/TBBx9o3759Zg8l5jQ3N2vdunXau3evkpKSzB5OzPJ6vSotLdXmzZslSYsWLdIHH3yg7du3E0ZC7D//8z/1H//xH3rxxRd100036eDBg3r44Yc1c+ZMftcxKi4v0+Tk5Mhms6m1tTVgf2trq/Ly8kwaVexau3at/uu//ku//e1vNWvWLLOHE3Pq6+vV1tamT3ziE0pISFBCQoLeeOMNff/731dCQoI8Ho/ZQ4wJM2bM0MKFCwP23XjjjWpqajJpRLHr7//+7/Xoo4/qK1/5im655RY98MADeuSRR1RTU2P20GKW77XPrNfFuAwjdrtdJSUlqq2t9e/zer2qra3VkiVLTBxZbDEMQ2vXrtXLL7+s//mf/1FhYaHZQ4pJn/nMZ/T+++/r4MGD/q20tFT333+/Dh48KJvNZvYQY8LSpUuvqaYfOXJEc+bMMWlEsaurq0tWa+DLk81mk9frNWlEsa+wsFB5eXkBr4sul0tvv/12RF4X4/YyTVVVlVatWqXS0lItXrxYW7duldvtVmVlpdlDixlr1qzRiy++qFdeeUXp6en+646ZmZlKTk42eXSxIz09/Zp1OKmpqcrOzmZ9Tgg98sgjuu2227R582b9xV/8herq6vTcc8/pueeeM3toMWfZsmV68sknNXv2bN100006cOCAtmzZoq9//etmD21Su3jxoo4ePer/+vjx4zp48KCmTp2q2bNn6+GHH9a3v/1tzZ8/X4WFhdqwYYNmzpyp5cuXh39wYe/rRLGnn37amD17tmG3243Fixcbv/vd78weUkyRNOT2b//2b2YPLeZR7Q2PX/7y
l8bNN99sOBwOY8GCBcZzzz1n9pBiksvlMtatW2fMnj3bSEpKMubOnWs8/vjjRnd3t9lDm9R++9vfDvk3edWqVYZh9NV7N2zYYOTm5hoOh8P4zGc+Yxw+fDgiY7MYBre0AwAA5onLNSMAACB6EEYAAICpCCMAAMBUhBEAAGAqwggAADAVYQQAAJiKMAIAAExFGAEAAKYijAAAAFMRRgAAgKkIIwAAwFSEEQAAYKr/D4CZx1SDO3QzAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the recorded test-loss history from the DYS training run above.\n",
    "# NOTE(review): relies on `plt` (matplotlib) and `test_loss_hist` being\n",
    "# defined in earlier cells outside this chunk — confirm imports.\n",
    "plt.plot(test_loss_hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0126, 0.0725, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071, 0.7071,\n",
      "        0.7071, 1.0000, 1.0000], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: inspect the scaling vector stored on the trained DYS network.\n",
    "# NOTE(review): DYS_net is built/trained in an earlier cell (not in this chunk).\n",
    "print(DYS_net.s_inv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "202\n"
     ]
    }
   ],
   "source": [
    "# Dimension attribute of the DYS layer's constraint system (prints 202 here).\n",
    "# NOTE(review): presumably the lifted dimension (items + slack variables) —\n",
    "# confirm against the definition in ModelsKnapSack.\n",
    "print(DYS_net.n2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4.5900, 4.8700, 5.1900,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        [5.1100, 3.1600, 7.0100,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        [1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],\n",
      "        ...,\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 1.0000, 0.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.0000, 0.0000],\n",
      "        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.0000]],\n",
      "       device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "# Constraint matrix used inside the DYS layer: the printed output shows the\n",
    "# two knapsack weight rows first, followed by what looks like an identity\n",
    "# block (slack columns) — NOTE(review): confirm in ModelsKnapSack.\n",
    "print(DYS_net.A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4.59 4.87 5.19 7.59 6.8  4.97 4.84 3.22 7.9  3.62 3.12 6.81 7.09 3.23\n",
      "  4.47 5.23 6.11 5.23 4.27 6.66 5.21 3.72 7.83 5.64 7.56 3.28 4.31 3.85\n",
      "  6.43 4.26 7.56 6.4  6.09 6.4  4.5  6.31 6.35 3.07 6.03 7.39 3.48 5.57\n",
      "  6.27 4.55 5.64 5.66 3.27 6.68 4.2  5.61 6.13 7.88 4.61 4.92 6.09 7.95\n",
      "  7.11 4.02 3.07 7.84 6.29 5.78 5.57 5.65 7.41 5.21 3.42 6.48 5.77 7.67\n",
      "  7.32 3.31 4.36 4.75 3.38 3.44 6.3  4.85 7.09 5.7  4.05 6.04 6.51 3.41\n",
      "  7.94 6.05 7.18 3.64 4.35 4.45 4.27 4.87 5.74 3.16 5.41 3.81 3.51 5.98\n",
      "  5.79 6.62]\n",
      " [5.11 3.16 7.01 6.48 6.36 4.12 7.9  7.69 3.68 4.73 5.85 5.03 3.69 7.13\n",
      "  7.11 5.72 5.8  4.59 6.85 4.14 6.18 4.28 5.38 5.79 6.78 3.79 4.61 3.06\n",
      "  7.78 3.94 6.06 7.22 5.18 3.44 3.93 5.56 3.56 6.59 5.   5.24 7.94 6.75\n",
      "  6.37 7.64 3.4  3.87 3.34 4.95 3.53 4.64 7.28 6.9  5.16 4.49 5.04 7.82\n",
      "  7.41 6.92 7.45 7.29 6.31 6.6  4.66 4.07 7.25 5.58 5.33 7.82 7.37 6.23\n",
      "  4.12 7.24 6.1  6.26 6.09 4.87 6.12 7.2  7.48 4.87 6.66 5.1  6.17 4.69\n",
      "  4.41 4.04 6.   3.71 6.82 5.15 6.95 3.47 4.16 5.17 7.97 6.08 3.87 7.03\n",
      "  3.88 3.13]]\n"
     ]
    }
   ],
   "source": [
    "# Raw item weights from pyepo's data generator: one row per knapsack\n",
    "# constraint (num_constraints x num_item).\n",
    "print(weights_numpy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3. 5. 7. 3. 8. 5. 5. 7. 4. 6. 6. 3. 3. 5. 7. 6. 7. 6. 4. 6. 5. 4. 6. 6.\n",
      " 6. 4. 8. 7. 4. 6. 6. 7. 6. 5. 6. 6. 7. 8. 3. 6. 6. 6. 4. 5. 6. 3. 5. 6.\n",
      " 8. 3. 4. 7. 7. 5. 4. 7. 4. 8. 8. 4. 5. 3. 5. 5. 6. 4. 9. 4. 6. 5. 6. 6.\n",
      " 7. 7. 5. 7. 7. 4. 7. 7. 8. 6. 5. 7. 7. 9. 5. 5. 7. 5. 7. 7. 7. 6. 7. 8.\n",
      " 6. 6. 5. 5.]\n"
     ]
    }
   ],
   "source": [
    "# True item values (costs) for the first generated sample.\n",
    "print(costs_numpy[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([18., 18.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
      "         1.,  1.,  1.,  1.], device='mps:0')\n"
     ]
    }
   ],
   "source": [
    "# Right-hand side vector of the DYS layer's constraints; the first two\n",
    "# entries (18, 18) look like the knapsack capacities and the remaining ones\n",
    "# like per-item upper bounds of 1 — NOTE(review): confirm in ModelsKnapSack.\n",
    "print(DYS_net.b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(80.3926, device='mps:0')"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Frobenius norm of the weight tensor — quick sanity check of its scale.\n",
    "torch.norm(weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Preparing the benchmark approaches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Gurobi knapsack model shared by all benchmark approaches.\n",
    "# NOTE(review): `capacities_numpy` is created in an earlier cell outside this\n",
    "# chunk — confirm it matches the capacities used by DYS_net.\n",
    "optmodel = pyepo.model.grb.knapsackModel(weights_numpy, capacities_numpy)\n",
    "# Value-prediction network for the two-stage baseline\n",
    "val_net = ValPredictNet(num_constraints, num_item, num_feat, device=device)\n",
    "val_net.to(device)\n",
    "max_epochs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Train a two-stage approach\n",
    "# Adam on the value-prediction net; MSE against the true item values.\n",
    "# The plateau scheduler lowers the LR when the monitored loss stops improving.\n",
    "optimizer = optim.Adam(val_net.parameters(), lr=1e-2)\n",
    "criterion = nn.MSELoss()\n",
    "scheduler = ReduceLROnPlateau(optimizer, 'min')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:  0 test loss is  6.611246109008789\n",
      "epoch:  1 test loss is  3.933697462081909\n",
      "epoch:  2 test loss is  3.5205788612365723\n",
      "epoch:  3 test loss is  3.3509225845336914\n",
      "epoch:  4 test loss is  3.2458724975585938\n",
      "epoch:  5 test loss is  3.165217399597168\n",
      "epoch:  6 test loss is  3.1080434322357178\n",
      "epoch:  7 test loss is  3.0606706142425537\n",
      "epoch:  8 test loss is  3.0264034271240234\n",
      "epoch:  9 test loss is  2.99825119972229\n",
      "epoch:  10 test loss is  2.9851582050323486\n"
     ]
    }
   ],
   "source": [
    "## Train the two-stage baseline: fit val_net to predict item values via MSE.\n",
    "# NOTE(review): `train_loss_ave`, `test_loss_hist`, `loader_train` and\n",
    "# `loader_test` come from earlier cells outside this chunk.\n",
    "epoch = 0\n",
    "while epoch <= max_epochs:\n",
    "    val_net.train()  # once per epoch is enough; no need to call per batch\n",
    "    for d_batch, w_batch, opt_sol, opt_value in loader_train:\n",
    "        # opt_sol / opt_value are unused by the MSE objective; skip moving them\n",
    "        d_batch = d_batch.to(device)\n",
    "        w_batch = w_batch.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        w_predicted = val_net(d_batch)\n",
    "        # Convention: criterion(prediction, target) — MSE is symmetric, so\n",
    "        # this preserves the original loss value.\n",
    "        loss = criterion(w_predicted, w_batch)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Exponential moving average of the training loss\n",
    "        train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item()\n",
    "\n",
    "    # Evaluation: disable autograd and accumulate a plain float.\n",
    "    # BUG FIX: the original summed live loss tensors, which retains the\n",
    "    # autograd graph (and device tensors) for every test batch, and then\n",
    "    # appended that tensor to test_loss_hist.\n",
    "    val_net.eval()\n",
    "    test_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for d_batch, w_batch, opt_sol, opt_value in loader_test:\n",
    "            d_batch = d_batch.to(device)\n",
    "            w_batch = w_batch.to(device)\n",
    "            w_predicted = val_net(d_batch)\n",
    "            test_loss += criterion(w_predicted, w_batch).item()\n",
    "\n",
    "    scheduler.step(test_loss)\n",
    "    test_loss_hist.append(test_loss)\n",
    "    print('epoch: ', epoch, 'test loss is ', test_loss)\n",
    "    epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal value is  [38.]  while predicted value is  34.0\n",
      "Optimal value is  [42.]  while predicted value is  37.0\n",
      "Optimal value is  [33.]  while predicted value is  31.0\n",
      "Optimal value is  [40.]  while predicted value is  26.0\n",
      "Optimal value is  [53.]  while predicted value is  41.0\n",
      "Optimal value is  [39.]  while predicted value is  31.0\n",
      "Optimal value is  [42.]  while predicted value is  37.0\n",
      "Optimal value is  [32.]  while predicted value is  21.0\n",
      "Optimal value is  [55.]  while predicted value is  47.0\n",
      "Optimal value is  [50.]  while predicted value is  48.0\n",
      "Optimal value is  [42.]  while predicted value is  37.0\n",
      "Optimal value is  [33.]  while predicted value is  25.0\n",
      "Optimal value is  [41.]  while predicted value is  36.0\n",
      "Optimal value is  [40.]  while predicted value is  31.0\n",
      "Optimal value is  [36.]  while predicted value is  34.0\n",
      "Optimal value is  [50.]  while predicted value is  44.0\n",
      "Optimal value is  [41.]  while predicted value is  30.0\n",
      "Optimal value is  [46.]  while predicted value is  44.0\n",
      "Optimal value is  [35.]  while predicted value is  24.0\n",
      "Optimal value is  [37.]  while predicted value is  34.0\n",
      "Optimal value is  [39.]  while predicted value is  33.0\n",
      "Optimal value is  [40.]  while predicted value is  38.0\n",
      "Optimal value is  [36.]  while predicted value is  33.0\n",
      "Optimal value is  [37.]  while predicted value is  29.0\n",
      "Optimal value is  [47.]  while predicted value is  36.0\n",
      "Optimal value is  [59.]  while predicted value is  50.0\n",
      "Optimal value is  [60.]  while predicted value is  54.0\n",
      "Optimal value is  [42.]  while predicted value is  42.0\n",
      "Optimal value is  [56.]  while predicted value is  44.0\n",
      "Optimal value is  [32.]  while predicted value is  24.0\n",
      "Optimal value is  [44.]  while predicted value is  40.0\n",
      "Optimal value is  [45.]  while predicted value is  35.0\n",
      "Optimal value is  [44.]  while predicted value is  36.0\n",
      "Optimal value is  [50.]  while predicted value is  41.0\n",
      "Optimal value is  [41.]  while predicted value is  37.0\n",
      "Optimal value is  [53.]  while predicted value is  46.0\n",
      "Optimal value is  [45.]  while predicted value is  36.0\n",
      "Optimal value is  [47.]  while predicted value is  36.0\n",
      "Optimal value is  [37.]  while predicted value is  34.0\n",
      "Optimal value is  [42.]  while predicted value is  37.0\n",
      "Optimal value is  [35.]  while predicted value is  30.0\n",
      "Optimal value is  [45.]  while predicted value is  33.0\n",
      "Optimal value is  [48.]  while predicted value is  38.0\n",
      "Optimal value is  [36.]  while predicted value is  28.0\n",
      "Optimal value is  [38.]  while predicted value is  34.0\n",
      "Optimal value is  [40.]  while predicted value is  33.0\n",
      "Optimal value is  [45.]  while predicted value is  38.0\n",
      "Optimal value is  [50.]  while predicted value is  44.0\n",
      "Optimal value is  [42.]  while predicted value is  30.0\n",
      "Optimal value is  [38.]  while predicted value is  25.0\n",
      "Optimal value is  [34.]  while predicted value is  30.0\n",
      "Optimal value is  [36.]  while predicted value is  29.0\n",
      "Optimal value is  [49.]  while predicted value is  45.0\n",
      "Optimal value is  [42.]  while predicted value is  37.0\n",
      "Optimal value is  [38.]  while predicted value is  32.0\n",
      "Optimal value is  [34.]  while predicted value is  30.0\n",
      "Optimal value is  [36.]  while predicted value is  32.0\n",
      "Optimal value is  [43.]  while predicted value is  35.0\n",
      "Optimal value is  [37.]  while predicted value is  30.0\n",
      "Optimal value is  [45.]  while predicted value is  30.0\n",
      "Optimal value is  [36.]  while predicted value is  21.0\n",
      "Optimal value is  [45.]  while predicted value is  37.0\n",
      "Optimal value is  [35.]  while predicted value is  28.0\n",
      "Optimal value is  [35.]  while predicted value is  20.0\n",
      "Optimal value is  [38.]  while predicted value is  37.0\n",
      "Optimal value is  [40.]  while predicted value is  32.0\n",
      "Optimal value is  [51.]  while predicted value is  37.0\n",
      "Optimal value is  [38.]  while predicted value is  34.0\n",
      "Optimal value is  [46.]  while predicted value is  42.0\n",
      "Optimal value is  [34.]  while predicted value is  30.0\n",
      "Optimal value is  [40.]  while predicted value is  33.0\n",
      "Optimal value is  [41.]  while predicted value is  37.0\n",
      "Optimal value is  [48.]  while predicted value is  31.0\n",
      "Optimal value is  [46.]  while predicted value is  41.0\n",
      "Optimal value is  [38.]  while predicted value is  29.0\n",
      "Optimal value is  [39.]  while predicted value is  35.0\n",
      "Optimal value is  [48.]  while predicted value is  42.0\n",
      "Optimal value is  [40.]  while predicted value is  30.0\n",
      "Optimal value is  [40.]  while predicted value is  37.0\n",
      "Optimal value is  [44.]  while predicted value is  39.0\n",
      "Optimal value is  [41.]  while predicted value is  26.0\n",
      "Optimal value is  [37.]  while predicted value is  30.0\n",
      "Optimal value is  [42.]  while predicted value is  41.0\n",
      "Optimal value is  [44.]  while predicted value is  38.0\n",
      "Optimal value is  [44.]  while predicted value is  32.0\n",
      "Optimal value is  [32.]  while predicted value is  28.0\n",
      "Optimal value is  [42.]  while predicted value is  40.0\n",
      "Optimal value is  [44.]  while predicted value is  37.0\n",
      "Optimal value is  [43.]  while predicted value is  37.0\n",
      "Optimal value is  [44.]  while predicted value is  34.0\n",
      "Optimal value is  [44.]  while predicted value is  38.0\n",
      "Optimal value is  [38.]  while predicted value is  31.0\n",
      "Optimal value is  [44.]  while predicted value is  28.0\n",
      "Optimal value is  [35.]  while predicted value is  26.0\n",
      "Optimal value is  [52.]  while predicted value is  47.0\n",
      "Optimal value is  [38.]  while predicted value is  30.0\n",
      "Optimal value is  [39.]  while predicted value is  32.0\n",
      "Optimal value is  [35.]  while predicted value is  26.0\n",
      "Optimal value is  [32.]  while predicted value is  31.0\n",
      "Optimal value is  [32.]  while predicted value is  26.0\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the two-stage baseline: re-solve the knapsack with the predicted\n",
    "# values and measure regret against the true optimum.\n",
    "val_net.eval()\n",
    "total_regret = 0.0\n",
    "total_opt_value = 0.0\n",
    "for d_batch, w_batch, opt_sol, opt_value in loader_test:\n",
    "    d_batch = d_batch.to(device)\n",
    "    # true values / optima as numpy for the Gurobi model\n",
    "    w_batch = w_batch.to(\"cpu\").detach().numpy()\n",
    "    opt_value = opt_value.to(\"cpu\").detach().numpy()\n",
    "    # prediction only — no gradients needed\n",
    "    with torch.no_grad():\n",
    "        w_predicted = val_net(d_batch)\n",
    "    batch_size = d_batch.shape[0]\n",
    "    for i in range(batch_size):\n",
    "        optmodel.setObj(w_predicted[i])\n",
    "        pred_sol, pred_val = optmodel.solve()\n",
    "        # BUG FIX: np.max(x, 0) treats 0 as an *axis* argument (a legacy no-op\n",
    "        # on a 0-d array), so negative regrets were never clamped to zero.\n",
    "        # np.maximum performs the intended element-wise clamp.\n",
    "        total_regret += np.maximum(np.squeeze(opt_value[i]) - np.dot(w_batch[i], pred_sol), 0)\n",
    "        print('Optimal value is ', opt_value[i], ' while predicted value is ', np.dot(w_batch[i], pred_sol))\n",
    "    total_opt_value += np.sum(opt_value)\n",
    "\n",
    "# BUG FIX: normalized_regret was previously computed inside the batch loop,\n",
    "# so only the last batch contributed; normalize over the whole test set.\n",
    "normalized_regret = total_regret/total_opt_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.17083734359961503\n"
     ]
    }
   ],
   "source": [
    "# Normalized regret of the two-stage baseline on the test data.\n",
    "print(normalized_regret)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num of cores: 1\n"
     ]
    }
   ],
   "source": [
    "## SPO+ Loss benchmark: set up a second value-prediction network.\n",
    "# NOTE(review): device is switched to CPU here, presumably because SPO+\n",
    "# invokes the Gurobi solver (CPU-only) inside the loss — confirm.\n",
    "device = 'cpu'\n",
    "val_net2 = ValPredictNet(num_constraints, num_item, num_feat, device=device)\n",
    "val_net2.to(device)\n",
    "# set optimizer\n",
    "optimizer = torch.optim.Adam(val_net2.parameters(), lr=1e-2)\n",
    "# init SPO+ loss (single worker process for the embedded solver)\n",
    "criterion = pyepo.func.SPOPlus(optmodel, processes=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:  0 test loss is  21.098979949951172\n",
      "epoch:  1 test loss is  19.743213653564453\n",
      "epoch:  2 test loss is  19.175342559814453\n",
      "epoch:  3 test loss is  18.501697540283203\n",
      "epoch:  4 test loss is  18.271207809448242\n",
      "epoch:  5 test loss is  18.415706634521484\n",
      "epoch:  6 test loss is  18.457067489624023\n",
      "epoch:  7 test loss is  18.404882431030273\n",
      "epoch:  8 test loss is  18.315059661865234\n",
      "epoch:  9 test loss is  18.289501190185547\n",
      "epoch:  10 test loss is  18.100351333618164\n"
     ]
    }
   ],
   "source": [
    "# Train val_net2 with the SPO+ loss (same loop shape as the MSE baseline).\n",
    "epoch = 0\n",
    "while epoch <= max_epochs:\n",
    "    val_net2.train()  # once per epoch is enough; no need to call per batch\n",
    "    for d_batch, w_batch, opt_sol, opt_value in loader_train:\n",
    "        d_batch = d_batch.to(device)\n",
    "        w_batch = w_batch.to(device)\n",
    "        opt_sol = opt_sol.to(device)\n",
    "        opt_value = opt_value.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        w_predicted = val_net2(d_batch)\n",
    "        # SPO+ loss consumes true costs, optimal solutions and optimal values\n",
    "        loss = criterion(w_predicted, w_batch, opt_sol, opt_value).mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Exponential moving average of the training loss\n",
    "        train_loss_ave = 0.95*train_loss_ave + 0.05*loss.item()\n",
    "\n",
    "    # Evaluation: disable autograd and accumulate a plain float.\n",
    "    # BUG FIX: the original summed live loss tensors, retaining the autograd\n",
    "    # graph for every test batch, and appended that tensor to test_loss_hist.\n",
    "    val_net2.eval()\n",
    "    test_loss = 0.0\n",
    "    with torch.no_grad():\n",
    "        for d_batch, w_batch, opt_sol, opt_value in loader_test:\n",
    "            d_batch = d_batch.to(device)\n",
    "            w_batch = w_batch.to(device)\n",
    "            opt_sol = opt_sol.to(device)\n",
    "            opt_value = opt_value.to(device)\n",
    "            w_predicted = val_net2(d_batch)\n",
    "            test_loss += criterion(w_predicted, w_batch, opt_sol, opt_value).mean().item()\n",
    "\n",
    "    test_loss_hist.append(test_loss)\n",
    "    print('epoch: ', epoch, 'test loss is ', test_loss)\n",
    "    epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
